#!/usr/bin/env python3

import functools
import itertools
import os
import random
from collections import defaultdict, deque
from copy import deepcopy
from pathlib import Path
from typing import Callable, Dict, List, Optional, Sequence

import numpy as np
import torch
import wandb
from rpi import logger
from rpi.agents.base import Agent
from rpi.agents.mamba import (ActivePolicySelector, ActiveStateExplorer,
                                MaxValueFn, ValueEnsemble)
from rpi.agents.ppo import update_critic_ensemble, update_state_pred_ensemble
from rpi.agents.mamba import StatePredictorEnsemble
from rpi.agents.ppo import PPOAgent
from rpi.helpers import set_random_seed, to_torch
from rpi.helpers.data import flatten
from rpi.helpers.env import rollout_single_ep, rollout, roll_in_and_out_lops, roll_in_and_out_mamba
from rpi.helpers.initializers import ortho_init
from rpi.nn.empirical_normalization import EmpiricalNormalization
from rpi.policies import (GaussianHeadWithStateIndependentCovariance,
                            SoftmaxCategoricalHead)
from rpi.scripts.sweep.default_args import Args
from rpi.value_estimations import (
    _attach_advantage_and_value_target_to_episode,
    _attach_log_prob_to_episodes, _attach_return_and_value_target_to_episode,
    _attach_value_to_episodes)
from torch import nn
import time

from .train import tolist
from .train import Evaluator


def inspect_statepred(make_env: Callable, experts: List[Agent], state_predictors: List[nn.Module], state_pred_optimizers: List[torch.optim.Optimizer], evaluator: Evaluator, max_episode_len: int):
    """
    Train only the per-expert next-state prediction models, using rollouts
    generated by the corresponding experts, and log diagnostics to wandb.

    Steps:
    - Roll out every expert ``Args.pret_num_rollouts`` times in one shared env.
    - Sample 5 (state, action) reference pairs from each expert's first episode.
    - Inspect the predictors on the reference pairs (logged with 'step': 0).
    - Fit each predictor on its own expert's transitions, logging every loss.
    - Inspect the predictors again (logged with 'step': 1).

    Expectation:
    - the predicted mean should converge to somewhere
    - the predicted stddev should converge to some value; in the
      deterministic case it should go to zero

    Args:
        make_env: zero-argument callable returning the environment to roll out in.
        experts: expert agents; ``expert.act`` is used as the rollout policy.
        state_predictors: one next-state prediction model per expert.
        state_pred_optimizers: one optimizer per entry of ``state_predictors``.
        evaluator: provides ``inspect_state_predictor`` for the before/after logs.
        max_episode_len: step limit handed to ``rollout_single_ep``.
    """
    env = make_env()

    # One bounded episode buffer D^k per expert k, filled by rolling out pi^k.
    # (buffer size: 100 for CartPole, DIP, 2 for HalfCheetah and Ant)
    buffers = [deque(maxlen=Args.expert_buffer_size) for _ in experts]
    for _ in range(Args.pret_num_rollouts):
        for buffer, expert in zip(buffers, experts):
            act_fn = functools.partial(expert.act, mode=Args.deterministic_experts)
            # NOTE: no return / advantage / value targets are attached to the
            # episode here; only raw transitions are needed for state prediction.
            buffer.append(rollout_single_ep(env, act_fn, max_episode_len))

    # Fixed reference (state, action) pairs: 5 random draws (with replacement,
    # np.random.choice default) from the first stored episode of each expert.
    ref_stateacts = []
    for buffer in buffers:
        first_episode = buffer[0]
        sampled_inds = np.random.choice(len(first_episode), size=(5, ))
        for step_idx in sampled_inds:
            transition = first_episode[step_idx]
            ref_stateacts.append((transition['state'], transition['action']))

    # Snapshot of the untrained predictors.
    logs_before = evaluator.inspect_state_predictor(state_predictors=state_predictors, ref_stateacts=ref_stateacts)
    wandb.log({**logs_before, 'step': 0})

    # Fit each state prediction ensemble on its own expert's data.
    for expert_idx, (predictor, optimizer) in enumerate(zip(state_predictors, state_pred_optimizers)):
        transitions = flatten(buffers[expert_idx])

        # Feed the observed states to the predictor's observation normalizer
        # before running the updates.
        observed_states = to_torch([tr['state'] for tr in transitions])
        predictor.obs_normalizer.experience(observed_states)

        update_result = update_state_pred_ensemble(
            predictor,
            transitions,
            optimizer,
            num_epochs=Args.pret_num_epochs,
            batch_size=Args.batch_size,
            std_from_means=True,
        )
        _, loss_history, _ = update_result

        num_transitions = len(transitions)
        loss_key = f'pretrain-statepred/loss-{expert_idx}'
        count_key = f'pretrain-statepred/num_transitions-{expert_idx}'
        for pret_step, loss in enumerate(loss_history):
            wandb.log({
                loss_key: loss,
                count_key: num_transitions,
                'pret-step': pret_step,
            })

    # Snapshot of the trained predictors.
    logs_after = evaluator.inspect_state_predictor(state_predictors=state_predictors, ref_stateacts=ref_stateacts)
    wandb.log({**logs_after, 'step': 1})


def main():
    """
    Build the environment, learner, experts and per-expert state predictors
    from ``Args``, then run ``inspect_statepred`` on them.

    All modules are placed on a single ``device`` (CUDA when available,
    otherwise CPU) so the script also runs on hosts without a GPU.
    """
    import gym
    from rpi.agents.mamba import MambaAgent
    from .train import Factory, get_expert

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    gamma = Args.gamma  # 0.995
    lmd = Args.lmd  # 0.97
    load_expert_step = Args.load_expert_step

    set_random_seed(Args.seed)

    def _make_env(env_name='DartCartPole-v1', test=False, default_seed=0):
        """Create ``env_name``; test envs get a different seed (42 - default_seed)."""
        from rpi.helpers import env
        seed = default_seed if not test else 42 - default_seed
        if env_name.startswith('dmc'):
            # dmc environments receive their seed through task_kwargs as well.
            extra_kwargs = {'task_kwargs': {'random': seed}}
        else:
            extra_kwargs = {}
        return env.make_env(env_name, seed=seed, **extra_kwargs)

    def make_env(*args, **kwargs):  # TEMP
        """Environment factory bound to ``Args.env_name``."""
        return _make_env(Args.env_name, *args, **kwargs)

    test_env = make_env()
    state_dim = test_env.observation_space.low.size

    if isinstance(test_env.action_space, gym.spaces.Box):
        # Continuous action space
        act_dim = test_env.action_space.low.size
        policy_head = GaussianHeadWithStateIndependentCovariance(
            action_size=act_dim,
            var_type="diagonal",
            var_func=lambda x: torch.exp(2 * x),  # Parameterize log std
            var_param_init=0,  # log std = 0 => std = 1
        )
    else:
        # Discrete action space (assuming categorical)
        act_dim = test_env.action_space.n
        policy_head = SoftmaxCategoricalHead()

    logger.info('obs_dim', state_dim)
    logger.info('act_dim', act_dim)

    pi = Factory.create_pi(state_dim, act_dim, policy_head=policy_head)

    obs_normalizer = EmpiricalNormalization(state_dim, clip_threshold=5)
    # Use the selected device rather than a hard-coded 'cuda' so that the
    # CPU fallback above actually works.
    obs_normalizer.to(device)

    # Loading: experts (one checkpoint per entry of Args.load_expert_step),
    # each paired with its own state-predictor ensemble and optimizer.
    experts = []
    state_predictors = []
    state_pred_optimizers = []
    for idx in load_expert_step:
        checkpoint_path = Path(Args.experts_dir) / test_env.unwrapped.spec.id.lower() / f'step_{idx:06d}.pt'
        expert = get_expert(
            state_dim,
            act_dim,
            deepcopy(policy_head),
            checkpoint_path,
            # Pass None when Args.use_expert_obsnormalizer is set; otherwise
            # hand the expert the learner's normalizer.
            obs_normalizer=None if Args.use_expert_obsnormalizer else obs_normalizer,
        )

        state_predictor = StatePredictorEnsemble(lambda: Factory.create_state_nn(state_dim, act_dim),
                                                 num_state_nns=Args.num_expert_vfns,
                                                 state_dim=state_dim,
                                                 act_dim=act_dim,
                                                 obs_normalizer=EmpiricalNormalization(state_dim, clip_threshold=5),
                                                 std_from_means=Args.std_from_means)
        state_predictor.to(device)
        for state_nn in state_predictor.nns:
            # Orthogonal init for the modules at indices 0, 2 and 4 of each
            # network (presumably its linear layers — confirm against
            # Factory.create_state_nn).
            for layer_idx in (0, 2, 4):
                ortho_init(state_nn[layer_idx], gain=Args.expert_vfn_gain)

        experts.append(expert)
        state_predictors.append(state_predictor)
        state_pred_optimizers.append(
            torch.optim.Adam(state_predictor.parameters(), lr=1e-3)
        )

    vfn = MaxValueFn([expert.vfn for expert in experts], obs_normalizers=[expert.obs_normalizer for expert in experts])

    if Args.algorithm == 'pg-gae':
        # pg-gae trains its own value function instead of using the experts'.
        vfn = Factory.create_vfn(state_dim)
        optimizer = torch.optim.Adam(list(pi.parameters()) + list(vfn.parameters()), lr=1e-3, betas=(0.9, 0.99))
        learner = PPOAgent(pi, vfn, optimizer, obs_normalizer, gamma=gamma, lambd=lmd)
        vfn.to(device)
    else:
        optimizer = torch.optim.Adam(pi.parameters(), lr=1e-3, betas=(0.9, 0.99))
        learner = MambaAgent(pi, vfn, optimizer, obs_normalizer, gamma=gamma, lambd=lmd, use_ppo_loss=Args.use_ppo_loss)
    pi.to(device)
    learner.to(device)

    max_episode_len = 1000
    evaluator = Evaluator(make_env, max_episode_len=max_episode_len)
    inspect_statepred(make_env, experts, state_predictors, state_pred_optimizers, evaluator, max_episode_len=max_episode_len)


if __name__ == '__main__':
    # Script entry point: pick one configuration from a sweep file, pin the
    # process to a GPU, start a wandb run and launch main().
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("sweep_file", help="sweep file")
    # Required: the value indexes the sweep below and selects the GPU, so a
    # missing value (None) could only fail later with an opaque TypeError.
    parser.add_argument("-l", "--line-number", type=int, required=True,
                        help="zero-based index of the sweep entry to run")
    args = parser.parse_args()

    # Obtain kwargs from Sweep
    from params_proto.hyper import Sweep
    sweep = Sweep(Args).load(args.sweep_file)
    kwargs = list(sweep)[args.line_number]

    Args._update(kwargs)

    # Spread sweep entries round-robin over the GPUs of the machine.
    num_gpus = 4
    cvd = args.line_number % num_gpus

    os.environ['CUDA_VISIBLE_DEVICES'] = str(cvd)

    sweep_basename = os.path.splitext(os.path.basename(args.sweep_file))[0]
    wandb.login()
    wandb.init(
        # Set the project where this run will be logged
        project='alops-inspect-state-pred',
        group=sweep_basename,
        config=vars(Args),
    )
    run_name = (
        f"{Args.algorithm}-s{Args.seed}-l{Args.lmd}"
        f"-e{Args.load_expert_step}-d{Args.deterministic_experts}"
    )
    if Args.algorithm == "lops":
        run_name = f"{run_name}-sig{Args.ase_sigma}"
    wandb.run.name = run_name

    main()
